/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_MATRIX_MATMUL_ASYNC_CUSTOM_IMPL_H
#define EXAMPLES_MATRIX_MATMUL_ASYNC_CUSTOM_IMPL_H
#include "kernel_operator.h"
#include "lib/matmul_intf.h"

/*
 * Ceiling division: smallest q such that q * b >= a.
 * Written as a / b + (a % b != 0) instead of (a + b - 1) / b, because the
 * latter silently wraps (unsigned overflow) when a + b - 1 exceeds UINT32_MAX,
 * yielding 0 instead of the correct quotient.
 * Precondition: b != 0.
 */
__aicore__ inline uint32_t Ceiling(uint32_t a, uint32_t b)
{
    return a / b + (a % b != 0 ? 1 : 0);
}

/*
 * Multi-core matmul kernel built on the asynchronous AscendC Matmul API.
 *
 * Template parameters:
 *   aType/bType/cType/biasType - element types of A, B, C and bias tensors.
 *   cPos - logical position the result is produced to. Process() handles two
 *          cases: TPosition::GM (result written straight to global memory) and
 *          TPosition::VECIN (result staged tile-by-tile through the vector
 *          unit's queues before being copied out).
 *
 * A, B and bias are read from GM in ND format; C is produced in ND format at
 * cPos. CFG_MDL selects the Matmul configuration (presumably the multi-data-
 * load template — confirm against the Matmul API docs).
 */
template <typename aType, typename bType, typename cType, typename biasType, AscendC::TPosition cPos>
class MatmulAsyncKernel {
    public:
        __aicore__ inline MatmulAsyncKernel(){};
        // Binds the GM tensors, applies this core's offsets, and stores the tiling.
        __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling);
        // Runs the matmul; hasBias selects whether biasGlobal is applied.
        template <bool hasBias = false>
        __aicore__ inline void Process(AscendC::TPipe* pipe);
        AscendC::Matmul<
            AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType>,
            AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType>,
            AscendC::MatmulType<cPos, CubeFormat::ND, cType>,
            AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>, CFG_MDL> matmulObj;

    private:
        // Computes this core's element offsets into A/B/C/bias from its block
        // index, and registers tail sizes on matmulObj for edge cores.
        __aicore__ inline void CalcOffset(int32_t blockIdx, const TCubeTiling& tiling, int32_t& offsetA, int32_t& offsetB,
                                          int32_t& offsetC, int32_t& offsetBias);
        // Maps the i-th iterated (baseM x baseN) tile to its element offset in cGlobal.
        __aicore__ inline uint32_t CalcDstOffset(uint32_t i);

        AscendC::GlobalTensor<aType> aGlobal;        // this core's view of A (offset applied in Init)
        AscendC::GlobalTensor<bType> bGlobal;        // this core's view of B
        AscendC::GlobalTensor<cType> cGlobal;        // this core's view of C
        AscendC::GlobalTensor<biasType> biasGlobal;  // this core's view of bias
        AscendC::GlobalTensor<cType> workspaceGlobal; // per-core scratch for the VECIN path
        AscendC::TQue<AscendC::TPosition::VECIN, 1> cInQueue;   // stages matmul result tiles into the vector unit
        AscendC::TQue<AscendC::TPosition::VECOUT, 1> cOutQueue; // stages processed tiles out to GM
        TCubeTiling tiling;  // copy of the host-computed tiling parameters
};

/*
 * Binds the kernel's global tensors to the raw GM addresses, then advances
 * each tensor to the sub-block this core is responsible for.
 *
 * Parameters:
 *   a/b/bias/c - GM addresses of the full ND operands.
 *   workspace  - GM scratch area (used by Process() on the VECIN path).
 *   tiling     - host-computed TCubeTiling; copied into the member for later use.
 */
template <typename aType, typename bType, typename cType, typename biasType, AscendC::TPosition cPos>
__aicore__ inline void MatmulAsyncKernel<aType, bType, cType, biasType, cPos>::Init(
    GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling)
{
    // Fail fast if the system workspace is absent. (Fix: this guard previously
    // sat at the very end of the function, after every buffer and offset had
    // already been configured, so it protected nothing.)
    if (GetSysWorkSpacePtr() == nullptr) {
        return;
    }

    this->tiling = tiling;

    // Full-tensor views: A is M x Ka, B is Kb x N, C is M x N, bias is length N.
    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), tiling.M * tiling.Ka);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), tiling.Kb * tiling.N);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), tiling.M * tiling.N);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), tiling.N);
    workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(workspace), tiling.M * tiling.N);

    // Advance each view to this core's sub-block.
    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    CalcOffset(AscendC::GetBlockIdx(), tiling, offsetA, offsetB, offsetC, offsetBias);
    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];
    // Each core owns a private singleCoreM x singleCoreN slice of the workspace.
    workspaceGlobal = workspaceGlobal[AscendC::GetBlockIdx() * tiling.singleCoreM * tiling.singleCoreN];
}

/*
 * Runs the matmul for this core's sub-block.
 *
 * Parameters:
 *   pipe    - TPipe used to allocate the VECIN/VECOUT queue buffers (VECIN path only).
 *   hasBias - compile-time switch: when true, biasGlobal is registered before compute.
 *
 * Two compile-time paths, selected by the cPos template parameter:
 *   - GM:    the whole result is produced asynchronously into cGlobal
 *            (IterateAll launches, WaitIterateAll blocks until completion).
 *   - VECIN: compute is launched asynchronously via Iterate, and the result is
 *            drained tile-by-tile (baseM x baseN) through cInQueue, run through
 *            a placeholder vector stage, and copied out to cGlobal via cOutQueue.
 */
template <typename aType, typename bType, typename cType, typename biasType, AscendC::TPosition cPos>
template <bool hasBias>
__aicore__ inline void MatmulAsyncKernel<aType, bType, cType, biasType, cPos>::Process(AscendC::TPipe* pipe)
{
    // Bind this core's input views (offsets were applied in Init()).
    matmulObj.SetTensorA(aGlobal);
    matmulObj.SetTensorB(bGlobal);
    if constexpr (hasBias) {
        matmulObj.SetBias(biasGlobal);
    }

    if constexpr (cPos == AscendC::TPosition::GM) {
        // Asynchronous full-result path: launch, then block until the result
        // has landed in cGlobal. (The exact meaning of the template/flag
        // arguments should be confirmed against the Matmul API docs.)
        matmulObj.template IterateAll<false>(cGlobal, 0, false, true);
        matmulObj.WaitIterateAll();
    } else if constexpr (cPos == AscendC::TPosition::VECIN) {
        // Asynchronous tiled path: compute runs ahead into the per-core
        // workspace while the vector unit post-processes finished tiles.
        matmulObj.SetWorkspace(workspaceGlobal);
        matmulObj.template Iterate<false>();

        // One baseM x baseN tile per queue slot (queue depth 1 each).
        pipe->InitBuffer(cInQueue, 1, tiling.baseM * tiling.baseN * sizeof(cType));
        pipe->InitBuffer(cOutQueue, 1, tiling.baseM * tiling.baseN * sizeof(cType));

        // Scatter one contiguous baseM x baseN tile into the row-major M x N
        // output: baseM blocks of baseN elements each, with the destination
        // skipping the remaining (N - baseN) elements of every output row.
        // Lengths/strides are presumably in DEFAULT_C0_SIZE-byte units — TODO
        // confirm against the DataCopy API.
        AscendC::DataCopyParams copyParams = {
            (uint16_t)tiling.baseM,
            (uint16_t)(tiling.baseN * sizeof(cType) / AscendC::DEFAULT_C0_SIZE),
            (uint16_t)0,
            (uint16_t)((tiling.N - tiling.baseN) * sizeof(cType) / AscendC::DEFAULT_C0_SIZE)
        };
        // Total number of baseM x baseN tiles covering this core's singleCore block.
        uint32_t iterateTimes = Ceiling(tiling.singleCoreM, tiling.baseM) * Ceiling(tiling.singleCoreN, tiling.baseN);
        for (uint32_t i = 0; i < iterateTimes; ++i) {
            // Stage 1: pull the next finished tile of the async matmul into VECIN.
            auto cInLocal = cInQueue.AllocTensor<cType>();
            matmulObj.template GetTensorC<false>(cInLocal);
            cInQueue.EnQue(cInLocal);

            // Stage 2: placeholder vector stage — a plain copy stands in for
            // any real elementwise post-processing of the tile.
            auto src = cInQueue.DeQue<cType>();
            auto dst = cOutQueue.AllocTensor<cType>();
            DataCopy(dst, src, tiling.baseM * tiling.baseN);
            cOutQueue.EnQue(dst);
            cInQueue.FreeTensor(src);

            // Stage 3: write the processed tile to its slot in cGlobal.
            auto cOutLocal = cOutQueue.DeQue<cType>();
            DataCopy(cGlobal[CalcDstOffset(i)], cOutLocal, copyParams);
            cOutQueue.FreeTensor(cOutLocal);
        }
    }
    matmulObj.End();
}

/*
 * Computes this core's element offsets into A, B, C and bias from its block
 * index, and registers reduced tail sizes on matmulObj for cores whose block
 * at the M/N edges is smaller than singleCoreM x singleCoreN.
 *
 * Cores are numbered M-major: consecutive block indices first walk down the
 * M dimension, then step along N.
 */
template <typename aType, typename bType, typename cType, typename biasType, AscendC::TPosition cPos>
__aicore__ inline void MatmulAsyncKernel<aType, bType, cType, biasType, cPos>::CalcOffset(int32_t blockIdx,
    const TCubeTiling& tiling, int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    auto blocksAlongM = Ceiling(tiling.M, tiling.singleCoreM);
    auto mBlockIdx = blockIdx % blocksAlongM;
    auto nBlockIdx = blockIdx / blocksAlongM;

    // Element offsets inside the full row-major ND tensors.
    offsetA = mBlockIdx * tiling.singleCoreM * tiling.Ka;  // skip whole rows of A (M x Ka)
    offsetB = nBlockIdx * tiling.singleCoreN;              // column offset into B (Kb x N)
    offsetC = mBlockIdx * tiling.singleCoreM * tiling.N + nBlockIdx * tiling.singleCoreN;
    offsetBias = nBlockIdx * tiling.singleCoreN;           // bias follows the N split

    // Clamp this core's block to what actually remains of M and N.
    int32_t effectiveM = tiling.M - mBlockIdx * tiling.singleCoreM;
    if (effectiveM > tiling.singleCoreM) {
        effectiveM = tiling.singleCoreM;
    }
    int32_t effectiveN = tiling.N - nBlockIdx * tiling.singleCoreN;
    if (effectiveN > tiling.singleCoreN) {
        effectiveN = tiling.singleCoreN;
    }
    if (effectiveM != tiling.singleCoreM || effectiveN != tiling.singleCoreN) {
        matmulObj.SetTail(effectiveM, effectiveN);
    }
}

/*
 * Maps the flat iteration counter i onto (row-tile, col-tile) coordinates
 * according to the tiling's traversal order, then converts those to the
 * element offset of the tile's top-left corner in the row-major output
 * (row pitch is the full tiling.N).
 *
 * Returns: element offset into cGlobal for the i-th (baseM x baseN) tile.
 */
template <typename aType, typename bType, typename cType, typename biasType, AscendC::TPosition cPos>
__aicore__ inline uint32_t MatmulAsyncKernel<aType, bType, cType, biasType, cPos>::CalcDstOffset(uint32_t i)
{
    uint32_t tileRow;
    uint32_t tileCol;
    if (tiling.iterateOrder == 1) {
        // Order 1: the N dimension advances fastest.
        uint32_t nTiles = Ceiling(tiling.singleCoreN, tiling.baseN);
        tileRow = i / nTiles;
        tileCol = i % nTiles;
    } else {
        // Any other order: the M dimension advances fastest.
        uint32_t mTiles = Ceiling(tiling.singleCoreM, tiling.baseM);
        tileRow = i % mTiles;
        tileCol = i / mTiles;
    }
    return tileRow * tiling.baseM * tiling.N + tileCol * tiling.baseN;
}
#endif // EXAMPLES_MATRIX_MATMUL_ASYNC_CUSTOM_IMPL_H